library(parallel)
library(foreach)

# Kernel estimates of an effect across a moderator, with bootstrap CIs and plot.
#
# Validates the arguments, optionally selects a kernel bandwidth by
# cross-validation (crossvalidate()), estimates at `neval` evenly spaced values
# of the moderator `X` (coefs()), optionally bootstraps standard errors and
# percentile confidence intervals (resampling whole clusters when `cl` is
# given), and draws the estimates with the distribution of `X` shown as a strip
# beneath them.
#
# NOTE(review): the estimate is read from column "x" of coefs()' result and the
# evaluation points from column "X"; confirm this matches coefs().
#
# Arguments:
#   Y, D, X        Names (strings) of the outcome, treatment and moderator columns.
#   Z              Optional character vector of covariate column names.
#   weights        Optional name of a weights column (passed on to coefs()).
#   FE             Optional character vector of fixed-effect column names; these
#                  columns are recoded to numeric factor codes.
#   data           A data frame.
#   na.rm          TRUE drops rows with NAs in any used column; otherwise such
#                  NAs are an error.
#   CI             TRUE bootstraps SEs and confidence intervals.
#   conf.level     Confidence level in (0.5, 1].
#   cl             Optional cluster column name; enables a cluster bootstrap.
#   neval          Number of evaluation points along the moderator.
#   nboots         Number of bootstrap replications.
#   parallel, cores  Run on a `parallel` cluster of at most `cores` workers
#                  (only created when bootstrapping or cross-validating).
#   seed           Seed for set.seed(); NULL leaves the RNG state untouched.
#   bw             Bandwidth; NULL selects it by cross-validation.
#   grid           Number of candidate bandwidths (log-spaced from range(X)/50 to
#                  range(X)), or a vector of candidate bandwidths.
#   metric         Cross-validation metric: "MSPE" or "MAPE".
#   Ylabel, Dlabel, Xlabel  Axis-label names; default to the column names.
#   main           Optional plot title.
#   xlim, ylim     Optional numeric length-2 plot limits.
#   Xdistr         "histogram"/"hist" or "density" (NULL means "density").
#   file           Optional path; the plot is also written there as a PDF.
#
# Returns a list with `bw`, `est` (data frame with columns X and ME, plus SE,
# CI_lower and CI_upper when CI is TRUE), `graph` (the ggplot object) and, when
# the bandwidth was cross-validated, `CV.out`.
inter.kernel.spillover <- function(Y, D, X, Z = NULL, weights = NULL, FE = NULL,
                                   data, na.rm = FALSE, CI = TRUE,
                                   conf.level = 0.95, cl = NULL, neval = 50,
                                   nboots = 200, parallel = TRUE, cores = 4,
                                   seed = 2139, bw = NULL, grid = 20,
                                   metric = "MSPE", Ylabel = NULL,
                                   Dlabel = NULL, Xlabel = NULL, main = NULL,
                                   xlim = NULL, ylim = NULL,
                                   Xdistr = "histogram", file = NULL) {
  # Bindings for bare column names used inside aes(); presumably here to
  # silence R CMD check "no visible binding" notes.
  x <- NULL
  y <- NULL
  xmin <- NULL
  xmax <- NULL
  ymin <- NULL
  ymax <- NULL
  ME <- NULL
  CI_lower <- NULL
  CI_upper <- NULL

  # ---- Argument validation ----
  if (!is.character(Y)) {
    stop("Y is not a string.")
  }
  Y <- Y[1]
  if (!is.character(D)) {
    stop("D is not a string.")
  }
  D <- D[1]
  if (!is.character(X)) {
    stop("X is not a string.")
  }
  X <- X[1]
  if (!is.null(Z) && !is.character(Z)) {
    stop("Some element in Z is not a string.")
  }
  if (!is.null(FE) && !is.character(FE)) {
    stop("Some element in FE is not a string.")
  }
  if (!is.null(weights)) {
    if (!is.character(weights)) {
      stop("weights is not a string.")
    }
    weights <- weights[1]
  }
  if (!is.data.frame(data)) {
    stop("Not a data frame.")
  }
  if (!is.logical(na.rm) && !is.numeric(na.rm)) {
    stop("na.rm is not a logical flag.")
  }
  if (!is.logical(CI) && !is.numeric(CI)) {
    stop("CI is not a logical flag.")
  }
  if (!is.null(conf.level)) {
    if (!is.numeric(conf.level) || conf.level <= 0.5 || conf.level > 1) {
      stop("conf.level should be a number between 0.5 and 1.")
    }
  }
  if (!is.null(cl)) {
    if (!is.character(cl)) {
      stop("cl is not a string.")
    }
    cl <- cl[1]
  }
  if (!is.null(neval)) {
    if (!is.numeric(neval)) {
      stop("neval is not a positive integer.")
    }
    neval <- neval[1]
    if (neval %% 1 != 0 || neval <= 0) {
      stop("neval is not a positive integer.")
    }
  }
  if (!is.null(nboots)) {
    if (!is.numeric(nboots)) {
      stop("nboots is not a positive integer.")
    }
    nboots <- nboots[1]
    if (nboots %% 1 != 0 || nboots < 1) {
      stop("nboots is not a positive integer.")
    }
  }
  if (!is.logical(parallel) && !is.numeric(parallel)) {
    stop("parallel is not a logical flag.")
  }
  if (!is.numeric(cores)) {
    stop("cores is not a positive integer.")
  }
  cores <- cores[1]
  if (cores %% 1 != 0 || cores <= 0) {
    stop("cores is not a positive integer.")
  }
  if (!is.null(bw)) {
    if (!is.numeric(bw)) {
      stop("bw should be a positive number.")
    }
    bw <- bw[1]
    if (bw <= 0) {
      stop("bw should be a positive number.")
    }
  }
  # NULL is allowed: it means "do not reset the RNG" (see set.seed() below).
  if (!is.null(seed) && !is.numeric(seed)) {
    stop("seed should be a number.")
  }
  if (!is.numeric(grid)) {
    stop("grid should be numeric.")
  }
  if (length(grid) == 1) {
    if (grid %% 1 != 0 || grid < 1) {
      stop("grid is not a positive integer.")
    }
  } else {
    grid <- grid[which(grid > 0)]
  }
  if (!metric %in% c("MSPE", "MAPE")) {
    stop("metric should be either \"MSPE\" or \"MAPE\".")
  }
  if (is.null(Ylabel)) {
    Ylabel <- Y
  } else if (!is.character(Ylabel)) {
    stop("Ylabel is not a string.")
  } else {
    Ylabel <- Ylabel[1]
  }
  if (is.null(Dlabel)) {
    Dlabel <- D
  } else if (!is.character(Dlabel)) {
    stop("Dlabel is not a string.")
  } else {
    Dlabel <- Dlabel[1]
  }
  if (is.null(Xlabel)) {
    Xlabel <- X
  } else if (!is.character(Xlabel)) {
    stop("Xlabel is not a string.")
  } else {
    Xlabel <- Xlabel[1]
  }
  if (!is.null(main)) {
    main <- main[1]
  }
  if (is.null(Xdistr)) {
    Xdistr <- "density"
  }
  if (!Xdistr %in% c("hist", "histogram", "density")) {
    stop("Xdistr must be either \"histogram\" or \"density\".")
  }
  if (!is.null(xlim)) {
    if (!is.numeric(xlim)) {
      stop("Some element in xlim is not numeric.")
    }
    if (length(xlim) != 2) {
      stop("xlim must be of length 2.")
    }
  }
  if (!is.null(ylim)) {
    if (!is.numeric(ylim)) {
      stop("Some element in ylim is not numeric.")
    }
    if (length(ylim) != 2) {
      stop("ylim must be of length 2.")
    }
  }
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # ---- Data preparation ----
  M <- c(Y, D, X, Z, FE, cl, weights)
  if (!all(M %in% names(data))) {
    stop("Wrong variable name.")
  }
  # Keep every referenced column, including the cluster column (read by the
  # bootstrap) and the weights column (its name is passed to coefs() and
  # crossvalidate()).
  used.vars <- unique(M)
  if (na.rm == TRUE) {
    data <- na.omit(data[, used.vars, drop = FALSE])
  } else if (any(is.na(data[, used.vars, drop = FALSE]))) {
    stop("Missing values. Try option na.rm = TRUE\n")
  }
  n <- nrow(data)
  if (is.null(cl) && !is.null(FE)) {
    warning("Fixed effects model assumed. Clustering standard errors highly recommended.")
  }
  if (!is.null(FE)) {
    data[FE] <- lapply(data[FE], function(vec) as.numeric(as.factor(vec)))
  }
  if (length(unique(data[, X])) < 5) {
    warning("Moderator has less than 5 values; consider a fully saturated model.")
  }

  # ---- Optional parallel backend ----
  # Only needed for the bootstrap (CI) or for bandwidth cross-validation
  # (bw = NULL; crossvalidate() is also given `parallel`).
  use.cluster <- parallel == TRUE && (CI == TRUE || is.null(bw))
  if (use.cluster) {
    if (!requireNamespace("doParallel", quietly = TRUE)) {
      stop("Package \"doParallel\" is needed for parallel computing.")
    }
    # detectCores() may return NA on some platforms.
    cores <- min(detectCores(), cores, na.rm = TRUE)
    pcl <- makeCluster(cores)
    # Release the workers on every exit path, including errors.
    on.exit(suppressWarnings(stopCluster(pcl)), add = TRUE)
    doParallel::registerDoParallel(pcl)
    cat("Parallel computing with", cores, "cores...\n")
  }

  # ---- Bandwidth ----
  X.eval <- seq(min(data[, X]), max(data[, X]), length.out = neval)
  CV <- is.null(bw)
  if (CV) {
    if (length(grid) == 1) {
      rangeX <- max(data[, X]) - min(data[, X])
      grid <- exp(seq(log(rangeX / 50), log(rangeX), length.out = grid))
    }
    cv.out <- crossvalidate(data = data, X.eval = X.eval, Y = Y, D = D, X = X,
                            Z = Z, FE = FE, cl = cl, weights = weights,
                            grid = grid, metric = metric, parallel = parallel)
    bw <- cv.out$opt.bw
  }

  # ---- Estimation ----
  if (CI == FALSE) {
    fit <- coefs(data = data, bw = bw, Y = Y, X = X, D = D, Z = Z, FE = FE,
                 X.eval = X.eval, weights = weights)
    est <- data.frame(X = fit[, "X"], ME = fit[, "x"])
  } else {
    coef <- coefs(data = data, bw = bw, Y = Y, X = X, D = D, Z = Z, FE = FE,
                  X.eval = X.eval, weights = weights)[, "x"]
    if (!is.null(cl)) {
      # Row indices of each cluster; whole clusters are resampled below.
      id.list <- split(seq_len(n), data[, cl])
    }
    oneboot <- function() {
      if (is.null(cl)) {
        smp <- sample(1:n, n, replace = TRUE)
      } else {
        # Draw positions in id.list directly: split() orders groups by sorted
        # level (not unique()'s order of appearance), and sample.int() avoids
        # sample()'s special case for a length-one vector.
        picked <- sample.int(length(id.list), length(id.list), replace = TRUE)
        smp <- unlist(id.list[picked], use.names = FALSE)
      }
      s <- data[smp, , drop = FALSE]
      coefs(data = s, bw = bw, Y = Y, X = X, D = D, Z = Z, FE = FE,
            X.eval = X.eval, weights = weights)[, "x"]
    }
    cat("Bootstrapping ...")
    if (parallel == TRUE) {
      bootout <- suppressWarnings(
        foreach(i = 1:nboots, .combine = cbind,
                .export = c("oneboot", "coefs"),
                .inorder = FALSE) %dopar% {
          oneboot()
        }
      )
      cat("\r")
    } else {
      # One column per bootstrap replicate, one row per evaluation point.
      bootout <- matrix(NA, length(X.eval), nboots)
      for (i in seq_len(nboots)) {
        bootout[, i] <- oneboot()
        cat(if (i %% 50 == 0) i else ".")
      }
      cat("\r")
    }
    CI.lvl <- c((1 - conf.level) / 2, 1 - (1 - conf.level) / 2)
    ci <- t(apply(bootout, 1, quantile, CI.lvl))
    est <- data.frame(cbind(X = X.eval, ME = coef,
                            SE = apply(bootout, 1, sd),
                            CI_lower = ci[, 1], CI_upper = ci[, 2]))
  }
  if (use.cluster) {
    cat("\n")
  }

  # ---- Plot ----
  requireNamespace("ggplot2")
  x.label <- paste0("Moderator: ", Xlabel)
  y.label <- paste0("Marginal Effect of ", Dlabel, " on ", Ylabel)
  p1 <- ggplot() + geom_hline(yintercept = 0, colour = "white", size = 2)
  p1 <- p1 + geom_line(data = est, aes(X, ME))
  if (CI == TRUE) {
    p1 <- p1 + geom_ribbon(data = est,
                           aes(x = X, ymin = CI_lower, ymax = CI_upper),
                           alpha = 0.2)
    yrange <- na.omit(c(est$CI_lower, est$CI_upper))
  } else {
    yrange <- na.omit(c(est$ME))
  }
  p1 <- p1 + xlab(x.label) + ylab(y.label) +
    theme(axis.title = element_text(size = 15))
  if (!is.null(ylim)) {
    # With user limits, size the distribution strip from ylim instead.
    yrange <- c(ylim[2], ylim[1] + (ylim[2] - ylim[1]) * 1 / 6)
  }
  maxdiff <- max(yrange) - min(yrange)
  # The moderator-distribution strip spans [strip.base, strip.base + height].
  strip.height <- maxdiff / 5
  strip.base <- min(yrange) - strip.height
  binary.D <- length(unique(data[, D])) == 2

  if (Xdistr == "density") {
    feed.col <- col2rgb("gray50")
    # col2rgb() returns 0-255 channel values, so dividing by 1000 (not 255)
    # yields a darker gray than "gray50"; kept to preserve the plot's look.
    col.gray <- rgb(feed.col[1] / 1000, feed.col[2] / 1000, feed.col[3] / 1000)
    if (binary.D) {
      # Densities of X for the lower (control) and higher (treated) D value.
      d.vals <- sort(unique(data[, D]))
      de.co <- density(data[data[, D] == d.vals[1], X])
      de.tr <- density(data[data[, D] == d.vals[2], X])
      deX.co <- data.frame(
        x = de.co$x,
        y = de.co$y / max(de.co$y) * strip.height + strip.base
      )
      deX.tr <- data.frame(
        x = de.tr$x,
        y = de.tr$y / max(de.tr$y) * strip.height + strip.base
      )
      col.tr <- rgb(red = 1, blue = 0, green = 0)
      p1 <- p1 +
        geom_ribbon(data = deX.co, aes(x = x, ymax = y, ymin = strip.base),
                    fill = col.gray, alpha = 0.2) +
        geom_ribbon(data = deX.tr, aes(x = x, ymax = y, ymin = strip.base),
                    fill = col.tr, alpha = 0.2)
    } else {
      de <- density(data[, X])
      deX <- data.frame(
        x = de$x,
        y = de$y / max(de$y) * strip.height + strip.base
      )
      p1 <- p1 + geom_ribbon(data = deX,
                             aes(x = x, ymax = y, ymin = strip.base),
                             fill = col.gray, alpha = 0.2)
    }
  } else {
    hist.out <- hist(data[, X], breaks = 80, plot = FALSE)
    n.hist <- length(hist.out$mids)
    bin.width <- hist.out$mids[2] - hist.out$mids[1]
    hist.max <- max(hist.out$counts)
    histX <- data.frame(
      ymin = rep(strip.base, n.hist),
      ymax = hist.out$counts / hist.max * strip.height + strip.base,
      xmin = hist.out$mids - bin.width / 2,
      xmax = hist.out$mids + bin.width / 2
    )
    p1 <- p1 + geom_rect(data = histX,
                         aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
                         colour = "gray50", alpha = 0, size = 0.5)
    if (binary.D) {
      # Treated units (higher D value) binned with exactly the same breaks and
      # interval closure as the full histogram, overlaid in red.
      treat.X <- data[data[, D] == max(data[, D]), X]
      count1 <- hist(treat.X, breaks = hist.out$breaks, plot = FALSE)$counts
      histX$count1 <- count1 / hist.max * strip.height + strip.base
      p1 <- p1 + geom_rect(data = histX,
                           aes(xmin = xmin, xmax = xmax, ymin = ymin,
                               ymax = count1),
                           fill = "red", colour = "grey50", alpha = 0.3,
                           size = 0.5)
    }
  }

  if (!is.null(main)) {
    p1 <- p1 + ggtitle(main) +
      theme(plot.title = element_text(hjust = 0.5, size = 35,
                                      lineheight = 0.8, face = "bold"))
  }
  if (!is.null(xlim) || !is.null(ylim)) {
    ylim2 <- NULL
    if (!is.null(ylim)) {
      # Pad the requested range so the distribution strip stays visible.
      ylim2 <- c(ylim[1] - (ylim[2] - ylim[1]) * 0.25 / 6,
                 ylim[2] + (ylim[2] - ylim[1]) * 0.4 / 6)
    }
    p1 <- p1 + coord_cartesian(xlim = xlim, ylim = ylim2)
  }
  if (!is.null(file)) {
    pdf(file)
    # Close only the device opened here; graphics.off() would also close any
    # devices the caller had open.
    tryCatch(plot(p1), finally = dev.off())
  }

  output <- list(bw = bw, est = est, graph = p1)
  if (CV) {
    output <- c(output, list(CV.out = cv.out$CV.out))
  }
  output
}
